{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0894a735",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
    "import torch\n",
    "\n",
    "TORCH_DTYPE = 'bfloat16'\n",
    "nf4_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_quant_type='nf4',\n",
    "    bnb_4bit_use_double_quant=True,\n",
    "    bnb_4bit_compute_dtype=getattr(torch, TORCH_DTYPE)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6daae66f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "518884a9a63e4d37a8b4e4e23379ad97",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)okenizer_config.json:   0%|          | 0.00/762 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1ca5d122c2ee41c9986c28bc5f385201",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e0fb3855a0a24a64947a8d672647fe68",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3feca8c9fe394e26b9283dda6a45e054",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)cial_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained('mesolitica/malaysian-llama2-13b-32k-instructions')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0da5df04",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "919862dcdea649feb4d0b04453434dd6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)lve/main/config.json:   0%|          | 0.00/611 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ef7c7cfca2e14f74a4c7483d7fbc3854",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)fetensors.index.json:   0%|          | 0.00/29.9k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cab1b9163f7246e188ff429bbf1523af",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b70d0d71cdab4c34ab4e3437fd187d92",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)of-00003.safetensors:   0%|          | 0.00/9.95G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "834588ef5696456a9529fc12c65b8df5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)of-00003.safetensors:   0%|          | 0.00/9.90G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b72740cbd6c846a581c99ea817026115",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)of-00003.safetensors:   0%|          | 0.00/6.18G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ba537735f9964fbb9b66725fc8cb6c58",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "377a1dd67e5946a0a0bea34fd0b54db0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)neration_config.json:   0%|          | 0.00/183 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    'mesolitica/malaysian-llama2-13b-32k-instructions',\n",
    "    use_flash_attention_2 = True,\n",
    "    quantization_config = nf4_config\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f18b9442",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_llama_chat(messages):\n",
    "\n",
    "    system = messages[0]['content']\n",
    "    user_query = messages[-1]['content']\n",
    "\n",
    "    users, assistants = [], []\n",
    "    for q in messages[1:-1]:\n",
    "        if q['role'] == 'user':\n",
    "            users.append(q['content'])\n",
    "        elif q['role'] == 'assistant':\n",
    "            assistants.append(q['content'])\n",
    "\n",
    "    texts = [f'<s>[INST] <<SYS>>\\n{system}\\n<</SYS>>\\n\\n']\n",
    "    for u, a in zip(users, assistants):\n",
    "        texts.append(f'{u.strip()} [/INST] {a.strip()} </s><s>[INST] ')\n",
    "    texts.append(f'{user_query.strip()} [/INST]')\n",
    "    prompt = ''.join(texts).strip()\n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "07881cfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {'role': 'system', 'content': 'awak adalah AI yang mampu jawab segala soalan'},\n",
    "    {'role': 'user', 'content': 'kwsp tu apa'}\n",
    "]\n",
    "prompt = parse_llama_chat(messages)\n",
    "inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "67d3c231",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<s> [INST] <<SYS>>\\nawak adalah AI yang mampu jawab segala soalan\\n<</SYS>>\\n\\nkwsp tu apa [/INST] Kumpulan Wang Simpanan Pekerja (KWSP) merupakan sebuah badan berkanun yang ditubuhkan di bawah Akta Kumpulan Wang Simpanan Pekerja 1991 untuk menguruskan simpanan persaraan pekerja dan melaksanakan program persaraan kepada pekerja dan tanggungan mereka. Visi KWSP adalah untuk menjadi pihak berkuasa persaraan terkemuka dalam membina dan membangunkan simpanan untuk masa depan yang lebih baik. </s>'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generate_kwargs = dict(\n",
    "    inputs,\n",
    "    max_new_tokens=1024,\n",
    "    top_p=0.95,\n",
    "    top_k=50,\n",
    "    temperature=0.9,\n",
    "    do_sample=True,\n",
    "    num_beams=1,\n",
    ")\n",
    "r = model.generate(**generate_kwargs)\n",
    "tokenizer.decode(r[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4efaf03f",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {'role': 'system', 'content': 'awak adalah AI yang mampu jawab segala soalan'},\n",
    "    {'role': 'user', 'content': 'awat malaysia ada jabatan koko, malaysia bukan buat keluaq koko banyak pun'}\n",
    "]\n",
    "prompt = parse_llama_chat(messages)\n",
    "inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bdc41dda",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<s> [INST] <<SYS>>\n",
      "awak adalah AI yang mampu jawab segala soalan\n",
      "<</SYS>>\n",
      "\n",
      "awat malaysia ada jabatan koko, malaysia bukan buat keluaq koko banyak pun [/INST] Jabatan Koko Malaysia ialah agensi kerajaan yang bertanggungjawab untuk pembangunan dan pengawalseliaan industri koko Malaysia. Sementara itu, pengeluaran koko adalah sebahagian kecil daripada ekonomi Malaysia. Antara faktor yang menyumbang kepada ini ialah kos tenaga buruh yang tinggi, kekurangan tanah untuk penanaman koko, dan kekurangan penglibatan dalam sektor hiliran. Walau bagaimanapun, industri koko Malaysia mempunyai potensi besar untuk berkembang dan memacu pertumbuhan ekonomi. </s>\n"
     ]
    }
   ],
   "source": [
    "generate_kwargs = dict(\n",
    "    inputs,\n",
    "    max_new_tokens=1024,\n",
    "    top_p=0.95,\n",
    "    top_k=50,\n",
    "    temperature=0.9,\n",
    "    do_sample=True,\n",
    "    num_beams=1,\n",
    ")\n",
    "r = model.generate(**generate_kwargs)\n",
    "print(tokenizer.decode(r[0]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
